import torch
from tqdm import tqdm

from .coteaching import CoTeachingTrainer


class JoCoR(CoTeachingTrainer):
    """Trainer that updates two models jointly with a single optimizer.

    Both models see every batch. The criterion receives both sets of logits,
    the labels, the current entry of ``self.rate_schedule`` and ``co_lambda``,
    and returns two loss values. Only the first one is back-propagated, through
    one Adam optimizer that holds the parameters of both models.

    NOTE(review): this assumes the criterion's first return value already
    carries the gradient signal for both models (joint loss) -- confirm
    against the criterion implementation.
    """

    def __init__(
        self,
        config,
        model,
        model2,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        num_gradual=10,
        forget_rate=0.2,
        exponet=1,
        mom1=0.9,
        mom2=0.1,
        co_lambda=0.5,
        epoch_decay_start=80,
        learning_rate=0.001,
        scheduler=None,
        val_set=None,
    ):
        """Set up the parent trainer, the joint optimizer and the decay plans.

        Args:
            config: Nested mapping; ``config["train"]["learning_rate"]`` is
                read here and ``config["general"]["logger"]["frequency"]`` is
                read in :meth:`run`.
            model: First network.
            model2: Second network.
            logger: Object exposing ``info(message, payload)``.
            train_set: Training dataset, forwarded to the parent trainer.
            test_set: Test dataset, forwarded to the parent trainer.
            criterion: Loss callable; forwarded to the parent for both of its
                criterion slots.
            optimizer: Forwarded to the parent for both of its optimizer
                slots; ``self.optimizer`` is replaced below by a joint Adam.
            num_gradual: Forwarded to the parent and stored on ``self``.
            forget_rate: Forwarded to the parent.
            exponet: Forwarded to the parent (spelling kept for interface
                compatibility).
            mom1: Forwarded to the parent; ``self.mom1`` fills the beta1 plan.
            mom2: Forwarded to the parent; used as beta1 from
                ``epoch_decay_start`` onwards.
            co_lambda: Weight passed to the criterion on every batch.
            epoch_decay_start: Forwarded to the parent; first epoch index of
                the linear learning-rate decay.
            learning_rate: Learning rate of the joint Adam optimizer.
            scheduler: Forwarded to the parent.
            val_set: Optional validation dataset, forwarded to the parent.
        """
        super().__init__(
            config,
            model,
            model2,
            logger,
            train_set,
            test_set,
            criterion,
            optimizer,
            criterion,
            optimizer,
            num_gradual,
            forget_rate,
            exponet,
            mom1,
            mom2,
            epoch_decay_start,
            scheduler,
            val_set,
        )
        self.learning_rate = learning_rate
        self.num_gradual = num_gradual
        self.co_lambda = co_lambda

        # NOTE(review): the plan is built from the config learning rate while
        # the optimizer below uses the ``learning_rate`` argument -- confirm
        # that the two are meant to be independent.
        base_lr = self.config["train"]["learning_rate"]
        self.alpha_plan = [base_lr] * self.epoch
        self.beta1_plan = [self.mom1] * self.epoch

        # A single optimizer over both parameter sets, so one backward pass
        # and one step update the two models together.
        self.optimizer = torch.optim.Adam(
            list(self.model.parameters()) + list(self.model2.parameters()),
            lr=self.learning_rate,
        )

        # Linear decay of the planned learning rate after ``epoch_decay_start``.
        # ``decay_span`` is only used inside the loop, where it is non-zero.
        decay_span = self.epoch - self.epoch_decay_start
        for i in range(self.epoch_decay_start, self.epoch):
            self.alpha_plan[i] = float(self.epoch - i) / decay_span * base_lr
            self.beta1_plan[i] = mom2

    def evaluate(self, val=True, second_model=False):
        """Compute and print top-1 accuracy on the validation or test loader.

        Args:
            val: Use ``self.val_loader`` when true, else ``self.test_loader``.
            second_model: Evaluate ``self.model2`` instead of ``self.model``.
                Falls back to the first model, with a printed notice, when no
                second model is available.

        Returns:
            Accuracy as a percentage; ``0.0`` when the loader yields no
            samples.
        """
        model_test = self.model
        if second_model:
            model_test = getattr(self, "model2", None)
            if model_test is None:
                print("There is no second model. Still testing the first model.")
                model_test = self.model
        model_test.eval()

        loader = self.val_loader if val else self.test_loader
        evaluate_type = "Val" if val else "Test"

        correct, total_num = 0, 0
        with torch.no_grad():
            for data in loader:
                inputs, labels, _, _ = self.prepare_data(data)
                outputs = model_test(inputs).cpu()
                labels = labels.cpu()
                correct += (outputs.argmax(1) == labels).sum().item()
                total_num += labels.size(0)

        # An empty loader would otherwise raise ZeroDivisionError.
        acc = correct / total_num * 100 if total_num else 0.0
        print(f"{evaluate_type} Acc: {acc:.4f}")
        return acc

    def run(self):
        """Train both models for ``self.epoch`` epochs.

        After every epoch the models are evaluated on the validation set (if
        one is configured) and on the test set. The best checkpoint per model
        is saved whenever its test accuracy improves, periodic checkpoints are
        written every ``self.save_every_epoch`` epochs, and the "last"
        checkpoints are refreshed.

        Reported training accuracy is the mean over batches of the first
        value returned by ``self.accuracy(..., topk=(1,))``.
        """
        print("==> Start training..")
        best_acc_1, best_acc_2 = 0.0, 0.0
        for cur_epoch in range(self.epoch):
            self.model.train()
            self.model2.train()
            self.adjust_learning_rate(cur_epoch)

            num_batches = 0
            epoch_loss_1, epoch_loss_2 = 0.0, 0.0
            train_correct_1, train_correct_2 = 0.0, 0.0
            with tqdm(self.train_loader, unit="batch") as tepoch:
                tepoch.set_description(f"Epoch {cur_epoch}")
                for data in tepoch:
                    inputs, labels, _, _ = self.prepare_data(data)
                    num_batches += 1

                    logits_1 = self.model(inputs)
                    prec_1 = self.accuracy(logits_1, labels, topk=(1,))
                    train_correct_1 += float(prec_1[0])

                    logits_2 = self.model2(inputs)
                    prec_2 = self.accuracy(logits_2, labels, topk=(1,))
                    train_correct_2 += float(prec_2[0])

                    loss_1, loss_2 = self.criterion(
                        logits_1,
                        logits_2,
                        labels,
                        self.rate_schedule[cur_epoch],
                        self.co_lambda,
                    )

                    self.optimizer.zero_grad()
                    loss_1.backward()
                    self.optimizer.step()

                    # Convert to Python floats once: accumulating the tensors
                    # themselves would keep autograd history alive all epoch.
                    loss_1_value = loss_1.item()
                    loss_2_value = loss_2.item()
                    epoch_loss_1 += loss_1_value
                    epoch_loss_2 += loss_2_value

                    acc_1 = train_correct_1 / num_batches
                    acc_2 = train_correct_2 / num_batches
                    lr = self.get_lr()

                    tepoch.set_postfix(
                        loss_1=loss_1_value,
                        accuracy_1=acc_1,
                        loss_2=loss_2_value,
                        accuracy_2=acc_2,
                        lr=lr,
                    )
                    self.global_iter += 1

                    if (
                        self.global_iter % self.config["general"]["logger"]["frequency"]
                        == 0
                    ):
                        self.logger.info(
                            f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss 1: {loss_1_value:.4f}, Loss 2: {loss_2_value:.4f}, Acc 1: {acc_1:.4f}, Acc 2: {acc_2:.4f}, lr: {lr:.6f}",
                            {
                                "cur_epoch": cur_epoch,
                                "iter": self.global_iter,
                                "loss 1": loss_1_value,
                                "loss 2": loss_2_value,
                                "Accuracy 1": acc_1,
                                "Accuracy 2": acc_2,
                                "lr": lr,
                            },
                        )

            # Guard the averages so an empty training loader reports zeros
            # instead of raising ZeroDivisionError.
            divisor = max(num_batches, 1)
            epoch_loss_1 /= divisor
            epoch_loss_2 /= divisor
            train_acc_1 = train_correct_1 / divisor
            train_acc_2 = train_correct_2 / divisor

            if self.val_set:
                self.evaluate(val=True)
                self.evaluate(val=True, second_model=True)
            test_acc_1 = self.evaluate(val=False)
            test_acc_2 = self.evaluate(val=False, second_model=True)

            if test_acc_1 > best_acc_1:
                best_acc_1 = test_acc_1
                self.save_best_model()

            if test_acc_2 > best_acc_2:
                best_acc_2 = test_acc_2
                self.save_best_model(second_model=True)

            summary = f"Epoch: {cur_epoch}, Loss 1: {epoch_loss_1:.6f}, Loss 2: {epoch_loss_2:.6f}, Train Acc 1: {train_acc_1:.4f}, Train Acc 2: {train_acc_2:.4f}, Test Acc 1: {test_acc_1:.4f}, Test Acc 2: {test_acc_2:.4f}, Best Test Acc 1: {best_acc_1:.4f}, Best Test Acc 2: {best_acc_2:.4f}"
            print(summary)
            self.logger.info(
                summary,
                {
                    "test_epoch": cur_epoch,
                    "loss 1": epoch_loss_1,
                    "loss 2": epoch_loss_2,
                    "Train Acc 1": train_acc_1,
                    "Train Acc 2": train_acc_2,
                    "Test Acc 1": test_acc_1,
                    "Test Acc 2": test_acc_2,
                    "Best Test Acc 1": best_acc_1,
                    "Best Test Acc 2": best_acc_2,
                },
            )

            if cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")
                self.save_model(f"{cur_epoch}", second_model=True)
            self.save_last_model()
            self.save_last_model(second_model=True)
